{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9c9cd486",
   "metadata": {},
   "source": [
    "# Wrangle MS-MARCO for Relevancy Classifications\n",
    "\n",
    "The purpose of this notebook is both to transform the MS-MARCO dataset into a relevancy classification format and to document the quirks of the original data format. MS-MARCO is a benchmark dataset used to evaluate information retrieval systems."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d09abc4a-65a6-4c5e-94ab-6538f0c16a51",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    "\n",
    "pd.set_option(\"display.max_colwidth\", 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1c33889-c950-41fc-b2a5-fa8b4a305ff2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset selection; `split` and `version` are reused further down to name the output file.\n",
    "split = \"train\"\n",
    "version = \"v1.1\"\n",
    "dataset = load_dataset(\"ms_marco\", name=version, split=split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80688937",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the loaded dataset to a DataFrame so the raw schema can be inspected.\n",
    "raw_df = dataset.to_pandas()\n",
    "raw_df.head(n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ffbf239",
   "metadata": {},
   "source": [
    "The \"wellFormedAnswers\" column contains only empty lists. This column is omitted from the final dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62ce9e45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tally the length of every entry in the column; a single bucket at 0 means all entries are empty.\n",
    "raw_df[\"wellFormedAnswers\"].map(len).value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86cbba10",
   "metadata": {},
   "source": [
    "Explode each row of the original dataset so that the new dataset contains one row per query-document pair (each passage listed for a query becomes its own row)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbe246e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One accumulator list per output column; every (query, passage) pair adds one entry to each.\n",
    "columns = {\n",
    "    \"query_id\": [],\n",
    "    \"query_text\": [],\n",
    "    \"query_type\": [],\n",
    "    \"relevant\": [],\n",
    "    \"document_text\": [],\n",
    "    \"document_url\": [],\n",
    "    \"reference_responses\": [],\n",
    "}\n",
    "for example in dataset:\n",
    "    passages = example[\"passages\"]\n",
    "    relevance_flags = [bool(flag) for flag in passages[\"is_selected\"]]\n",
    "    passage_texts = passages[\"passage_text\"]\n",
    "    passage_urls = passages[\"url\"]\n",
    "    num_passages = len(relevance_flags)\n",
    "    # The per-passage lists must be the same length, otherwise rows would be misaligned.\n",
    "    assert num_passages == len(passage_texts) == len(passage_urls)\n",
    "    # Per-passage fields are appended as they are.\n",
    "    columns[\"relevant\"].extend(relevance_flags)\n",
    "    columns[\"document_text\"].extend(passage_texts)\n",
    "    columns[\"document_url\"].extend(passage_urls)\n",
    "    # Per-query fields are repeated once for each of the query's passages.\n",
    "    columns[\"query_id\"].extend([example[\"query_id\"]] * num_passages)\n",
    "    columns[\"query_text\"].extend([example[\"query\"]] * num_passages)\n",
    "    columns[\"query_type\"].extend([example[\"query_type\"]] * num_passages)\n",
    "    columns[\"reference_responses\"].extend([example[\"answers\"]] * num_passages)\n",
    "df = pd.DataFrame(columns)\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be648e50",
   "metadata": {},
   "source": [
    "Compare the column names of the original dataset with the columns of the wrangled dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8e62c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw column names that have no identically named column in the wrangled frame.\n",
    "set(raw_df.columns) - set(df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "314b3411",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrangled column names that have no identically named column in the raw frame.\n",
    "set(df.columns) - set(raw_df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc73dc6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the columns that are written to the output file.\n",
    "classification_columns = [\"query_id\", \"query_text\", \"document_text\", \"document_url\", \"relevant\"]\n",
    "binary_relevance_classification_df = df[classification_columns]\n",
    "binary_relevance_classification_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "068c1a59",
   "metadata": {},
   "source": [
    "Write the data to a JSONL file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a791a93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The output file name encodes the dataset version and split it was built from.\n",
    "data_path = f\"ms_marco-{version}-{split}.jsonl\"\n",
    "# Pin the encoding so the file does not depend on the platform's locale default.\n",
    "with open(data_path, \"w\", encoding=\"utf-8\") as f:\n",
    "    for record in binary_relevance_classification_df.to_dict(orient=\"records\"):\n",
    "        # JSONL: one JSON object per line.\n",
    "        f.write(json.dumps(record) + \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
